package org.example.nebula.local

import com.facebook.thrift.protocol.TCompactProtocol
import com.vesoft.nebula.algorithm.lib.DegreeStaticAlgo
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, dense_rank}
import org.example.nebula.basic.ReadData
import org.example.nebula.local.RunAlgo.log

/**
 * Local (`local[1]`) driver that runs [[DegreeStaticAlgo]] on graph data read from NebulaGraph.
 *
 * Pipeline:
 *  1. read the edge data via `ReadData.readNebulaGraphData`;
 *  2. re-encode every vertex id appearing as `_srcId` or `_dstId` into a dense integer
 *     (`dense_rank` over the ids, so 1..N);
 *  3. run the degree algorithm on the encoded `src`/`dst` edge list;
 *  4. join the results back to the original ids and print them, then log the elapsed time.
 */
object DegreeLocal {

    def main(args: Array[String]): Unit = {
        val sparkConf = new SparkConf()
                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol]))
        val spark = SparkSession
                .builder()
                .appName("nebula-algo")
                .master("local[1]")
                .config(sparkConf)
                .getOrCreate()

        // Release the SparkSession even if reading the data or running the algorithm fails.
        try {
            val start = System.nanoTime()

            // Read the graph data from the graph database.
            val df = ReadData.readNebulaGraphData(spark)
            df.show()

            val encodeId = encodeVertexIds(df)
            log.warn("id编码后，id信息为：")

            val encodedDF = encodeEdges(df, encodeId)

            val pr = DegreeStaticAlgo.apply(spark, encodedDF)
            log.warn("计算结果为：")
            // pr.show()  // uncomment to inspect the raw (still encoded) algorithm output

            // Map the encoded ids back to the original ids.
            // NOTE(review): assumes the algorithm output keys its rows by an `_id` column holding
            // the encoded id — confirm against DegreeStaticAlgo.
            val decodedPr = encodeId
                    .join(pr, col("encodedId") === col("_id"))
                    .drop("encodedId", "_id")
            decodedPr.show()

            val elapsedMs = (System.nanoTime() - start) / 1000000
            log.warn(s"消耗时间${elapsedMs}ms")
        } finally {
            spark.stop()
        }
    }

    /**
     * Builds the id dictionary: every distinct vertex id seen as `_srcId` or `_dstId`, paired with
     * a dense integer code.
     *
     * @param edges edge DataFrame containing `_srcId` and `_dstId` columns
     * @return DataFrame with columns `id` (original id) and `encodedId` (`dense_rank` ordered by `id`)
     */
    private def encodeVertexIds(edges: DataFrame): DataFrame = {
        val srcIdDF = edges.select(col("_srcId").as("id"))
        val dstIdDF = edges.select(col("_dstId").as("id"))
        val idDF = srcIdDF.union(dstIdDF).distinct()
        // A window without partitionBy ranks all ids in a single partition; acceptable for a
        // local run, but it will not scale to large graphs.
        idDF.withColumn("encodedId", dense_rank().over(Window.orderBy("id")))
    }

    /**
     * Replaces the original endpoint ids of every edge with their dense codes.
     *
     * @param edges    edge DataFrame containing `_srcId` and `_dstId` columns
     * @param encodeId id dictionary produced by [[encodeVertexIds]] (`id`, `encodedId`)
     * @return DataFrame with exactly the columns `src` and `dst` (encoded ids)
     */
    private def encodeEdges(edges: DataFrame, encodeId: DataFrame): DataFrame = {
        val srcJoinDF = edges
                .join(encodeId, col("_srcId") === col("id"))
                .drop("_srcId", "id")
                .withColumnRenamed("encodedId", "src")
        srcJoinDF.cache()

        val dstJoinDF = srcJoinDF
                .join(encodeId, col("_dstId") === col("id"))
                .drop("_dstId", "id")
                .withColumnRenamed("encodedId", "dst")

        dstJoinDF.select("src", "dst")
    }

}
